# Define experiment parameters
year = "200506"  # dataset year tag used in data/model/result file names ("200506" or "201516")
target_col = "has_occ" # 'white_collar', 'blue_collar', 'has_occ'
sample_weight_col = 'women_weight'  # column holding per-respondent survey sampling weights
# Define resource utilization parameters
random_state = 42  # seed shared by numpy, the train/test split, CV and LightGBM
n_jobs_clf = 16  # parallel threads for the LightGBM classifier
n_jobs_cv = 4  # parallel workers for the (commented-out) hyperparameter search
cv_folds = 5  # number of stratified cross-validation folds
import numpy as np
np.random.seed(random_state)  # makes the np.random.choice subsample below reproducible
import pandas as pd
pd.set_option('display.max_columns', 500)  # show wide frames fully in notebook output
import matplotlib.pylab as pl
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.utils.class_weight import compute_class_weight
import lightgbm
from lightgbm import LGBMClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.model_selection import StratifiedKFold
import shap
import pickle
from joblib import dump, load
# Load dataset for the selected survey year
dataset = pd.read_csv(f"data/women_work_data_{year}.csv")
print("Loaded dataset: ", dataset.shape)
dataset.head()  # notebook display; no effect when run as a plain script
# See distribution of target values (dropna=False also counts missing targets)
print("Target column distribution:\n", dataset[target_col].value_counts(dropna=False))
# Drop samples where the target OR the sampling weight is missing
dataset.dropna(axis=0, subset=[target_col, sample_weight_col], inplace=True)
print("Drop missing targets: ", dataset.shape)
# Keep only respondents aged 21 and over
dataset = dataset[dataset['age'] >= 21]
print("Drop under-21 samples: ", dataset.shape)
# See new distribution of target values after filtering
print("Target column distribution:\n", dataset[target_col].value_counts(dropna=False))
# Post-processing: normalize caste labels and (for 2015-16) the wealth index.
# BUG FIX: the original used chained assignment (dataset['caste'][mask] = v),
# which triggers SettingWithCopyWarning and can silently fail to write (and
# raises under pandas copy-on-write). Use .loc for in-place assignment.
# Group SC/ST castes together
dataset.loc[dataset['caste'] == 'scheduled caste', 'caste'] = 'sc/st'
dataset.loc[dataset['caste'] == 'scheduled tribe', 'caste'] = 'sc/st'
if year == "200506":
    # In the 2005-06 wave the code '9' denotes "don't know"
    dataset.loc[dataset['caste'] == '9', 'caste'] = "don't know"
# Fix naming for General caste
dataset.loc[dataset['caste'] == 'none of above', 'caste'] = 'general'
if year == "201516":
    # Convert wealth index from str categories to ordinal int values
    wi_dict = {'poorest': 0, 'poorer': 1, 'middle': 2, 'richer': 3, 'richest': 4}
    dataset['wealth_index'] = [wi_dict[wi] for wi in dataset['wealth_index']]
# Feature columns, grouped by type (the groups are reused later for encoding)
x_cols_categorical = ['state', 'hh_religion', 'caste']
x_cols_binary = ['urban', 'women_anemic', 'obese_female']
x_cols_numeric = ['age', 'years_edu', 'wealth_index', 'hh_members', 'no_children_below5', 'total_children', 'freq_tv']
x_cols = [*x_cols_categorical, *x_cols_binary, *x_cols_numeric]
print("Feature columns:\n", x_cols)
# Keep only rows with a complete set of feature values
dataset.dropna(subset=x_cols, inplace=True)
print("Drop missing feature value rows: ", dataset.shape)
# Separate target column
targets = dataset[target_col]
# Separate sampling weight column
sample_weights = dataset[sample_weight_col]
# Keep only feature columns (drops the target, the weight and any extras).
# NOTE: the original also passed axis=1 to drop(); axis is redundant (and
# ignored) when columns= is given, so it has been removed.
dataset.drop(columns=[col for col in dataset.columns if col not in x_cols], inplace=True)
print("Drop extra columns: ", dataset.shape)
# Obtain one-hot encodings for the caste column
dataset = pd.get_dummies(dataset, columns=['caste'])
x_cols_categorical.remove('caste')  # 'caste' is now one-hot, no longer a single categorical column
print("Caste to one-hot: ", dataset.shape)
# Copy that keeps human-readable categorical values, used later for SHAP display
dataset_display = dataset.copy()
print("Create copy for visualization: ", dataset_display.shape)
dataset_display.head()
# Obtain integer encodings for the remaining categorical features
for col in x_cols_categorical:
    dataset[col] = pd.factorize(dataset[col])[0]
print("Categoricals to int encodings: ", dataset.shape)
dataset.head()
# Hold out a stratified 5% test set, carrying the survey weights along with
# the features and targets so per-sample weights stay aligned with each split.
X_train, X_test, Y_train, Y_test, W_train, W_test = train_test_split(
    dataset, targets, sample_weights,
    test_size=0.05, random_state=random_state, stratify=targets)
print("Training set: ", X_train.shape, Y_train.shape, W_train.shape)
print("Test set: ", X_test.shape, Y_test.shape, W_test.shape)
# Balanced class weights over the training labels (printed for reference)
train_cw = compute_class_weight("balanced", classes=np.unique(Y_train), y=Y_train)
print("Class weights: ", train_cw)
# # Define LightGBM Classifier
# model = LGBMClassifier(boosting_type='gbdt',
# feature_fraction=0.8,
# learning_rate=0.01,
# max_bins=64,
# max_depth=-1,
# min_child_weight=0.001,
# min_data_in_leaf=50,
# min_split_gain=0.0,
# num_iterations=1000,
# num_leaves=64,
# reg_alpha=0,
# reg_lambda=1,
# subsample_for_bin=200000,
# is_unbalance=True,
# random_state=random_state,
# n_jobs=n_jobs_clf,
# silent=True,
# importance_type='split')
# # Fit model on training set
# model.fit(X_train, Y_train, sample_weight=W_train.values,
# #categorical_feature=x_cols_categorical,
# categorical_feature=[])
# # Make predictions on Test set
# predictions = model.predict(X_test)
# print(accuracy_score(Y_test, predictions))
# print(f1_score(Y_test, predictions))
# print(confusion_matrix(Y_test, predictions))
# print(classification_report(Y_test, predictions))
# # Save trained model
# dump(model, f'models/{target_col}-{year}-model.joblib')
# del model
# # Define hyperparameter grid
# param_grid = {
# 'num_leaves': [8, 32, 64],
# 'min_data_in_leaf': [10, 20, 50],
# 'max_depth': [-1],
# 'learning_rate': [0.01, 0.1],
# 'num_iterations': [1000, 3000, 5000],
# 'subsample_for_bin': [200000],
# 'min_split_gain': [0.0],
# 'min_child_weight': [0.001],
# 'feature_fraction': [0.8, 1.0],
# 'reg_alpha': [0],
# 'reg_lambda': [0, 1],
# 'max_bin': [64, 128, 255]
# }
# # Define LightGBM Classifier
# clf = LGBMClassifier(boosting_type='gbdt',
# objective='binary',
# is_unbalance=True,
# random_state=random_state,
# n_jobs=n_jobs_clf,
# silent=True,
# importance_type='split')
# # Define K-fold cross validation splitter
# kfold = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
# # Perform grid search
# model = GridSearchCV(clf, param_grid=param_grid, scoring='f1', n_jobs=n_jobs_cv, cv=kfold, refit=True, verbose=3)
# model.fit(X_train, Y_train,
# sample_weight=W_train.values,
# #categorical_feature=x_cols_categorical,
# categorical_feature=[])
# print('\n All results:')
# print(model.cv_results_)
# print('\n Best estimator:')
# print(model.best_estimator_)
# print('\n Best hyperparameters:')
# print(model.best_params_)
# # Make predictions on Test set
# predictions = model.predict(X_test)
# print(accuracy_score(Y_test, predictions))
# print(f1_score(Y_test, predictions, average='micro'))
# print(confusion_matrix(Y_test, predictions))
# print(classification_report(Y_test, predictions))
# # Save trained model
# dump(model, f'models/{target_col}-{year}-gridsearch.joblib')
# del model
# Load the previously trained model from disk
model = load(f'models/{target_col}-{year}-model.joblib')
# model = load(f'models/{target_col}-{year}-gridsearch.joblib').best_estimator_
# Evaluate first on the held-out test set (sanity check), then on the
# training set (overfitting check); print order matches the original cells.
for X_eval, Y_eval in ((X_test, Y_test), (X_train, Y_train)):
    predictions = model.predict(X_eval)
    print(accuracy_score(Y_eval, predictions))
    print(f1_score(Y_eval, predictions))
    print(confusion_matrix(Y_eval, predictions))
    print(classification_report(Y_eval, predictions))
Note that these plots just explain how the LightGBM model works, not necessarily how reality works. Since the model is trained on observational data, it is not necessarily a causal model, so just because changing a factor makes the model's prediction go up does not always mean the actual outcome would change.
# Print the SHAP JS visualization bootstrap code to the notebook
# (required for the interactive force plots below)
shap.initjs()
If consistency fails to hold, then we can’t compare the attributed feature importances between any two models, because then having a higher assigned attribution doesn’t mean the model actually relies more on that feature.
If accuracy fails to hold then we don’t know how the attributions of each feature combine to represent the output of the whole model. We can’t just normalize the attributions after the method is done since this might break the consistency of the method.
# Build a TreeExplainer for the trained tree-ensemble model
explainer = shap.TreeExplainer(model)
# Load precomputed SHAP values from cache; recompute with the commented
# line below if the cache file is missing (slow).
# shap_values = explainer.shap_values(dataset)
# BUG FIX: the original pickle.load(open(...)) never closed the file handle;
# use a context manager. NOTE(review): pickle can execute arbitrary code on
# load -- only open trusted, self-produced cache files here.
with open(f'res/{target_col}-{year}-shapvals.obj', 'rb') as f:
    shap_values = pickle.load(f)
# Visualize a single prediction
shap.force_plot(explainer.expected_value, shap_values[0,:], dataset_display.iloc[0,:])
The above explanation shows features each contributing to push the model output from the base value (the average model output over the training dataset we passed) to the model output. Features pushing the prediction higher are shown in red, those pushing the prediction lower are in blue.
If we take many explanations such as the one shown above, rotate them 90 degrees, and then stack them horizontally, we can see explanations for an entire dataset (in the notebook this plot is interactive):
# Visualize many predictions as stacked force plots.
# BUG FIX: np.random.choice defaults to replace=True, so the original
# "subsample" could contain the same row several times; sample without
# replacement instead (dataset is far larger than 1000 rows here).
subsample = np.random.choice(len(dataset), 1000, replace=False)
shap.force_plot(explainer.expected_value, shap_values[subsample,:], dataset_display.iloc[subsample,:])
# Print global mean(|SHAP|) per feature
for col, sv in zip(dataset.columns, np.abs(shap_values).mean(0)):
    print(f"{col} - {sv}")
# Bar chart of the same global importances
shap.summary_plot(shap_values, dataset, plot_type="bar")
The above figure shows the global mean(|Tree SHAP|) method applied to our model.
The x-axis is essentially the average magnitude change in model output when a feature is “hidden” from the model (for this model the output has log-odds units). “Hidden” means integrating the variable out of the model. Since the impact of hiding a feature changes depending on what other features are also hidden, Shapley values are used to enforce consistency and accuracy.
However, since we now have individualized explanations for every person in our dataset, to get an overview of which features are most important for a model we can plot the SHAP values of every feature for every sample. The plot below sorts features by the sum of SHAP value magnitudes over all samples, and uses SHAP values to show the distribution of the impacts each feature has on the model output. The color represents the feature value (red high, blue low):
# Beeswarm summary: one dot per sample per feature, colored by feature value
shap.summary_plot(shap_values, dataset_display)
How to use this: We can carry out an analysis similar to the blog post for interpreting our models.
Next, to understand how a single feature affects the output of the model we can plot the SHAP value of that feature vs. the value of the feature for all the examples in a dataset. SHAP dependence plots show the effect of a single feature across the whole dataset. They plot a feature's value vs. the SHAP value of that feature across many samples.
SHAP dependence plots are similar to partial dependence plots, but account for the interaction effects present in the features, and are only defined in regions of the input space supported by data. The vertical dispersion of SHAP values at a single feature value is driven by interaction effects, and another feature is chosen for coloring to highlight possible interactions. One of the benefits of SHAP dependence plots over traditional partial dependence plots is this ability to distinguish between models with and without interaction terms. In other words, SHAP dependence plots give an idea of the magnitude of the interaction terms through the vertical variance of the scatter plot at a given feature value.
Good example of using Dependency Plots: https://slundberg.github.io/shap/notebooks/League%20of%20Legends%20Win%20Prediction%20with%20XGBoost.html
# SHAP dependence plots for 'age', each colored by an interaction feature,
# plus 'hh_religion' and 'state' colored by age.
age_partners = ['age', 'urban', 'caste_sc/st', 'caste_general', 'wealth_index',
                'years_edu', 'no_children_below5', 'total_children']
pairs = [('age', p) for p in age_partners] + [('hh_religion', 'age'), ('state', 'age')]
for feature, partner in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {partner}")
    shap.dependence_plot(feature, shap_values, dataset, display_features=dataset_display, interaction_index=partner)
# SHAP dependence plots for 'wealth_index', each colored by an interaction
# feature, plus 'hh_religion' and 'state' colored by wealth_index.
wealth_partners = ['wealth_index', 'age', 'urban', 'caste_sc/st', 'caste_general',
                   'years_edu', 'no_children_below5', 'total_children']
pairs = [('wealth_index', p) for p in wealth_partners]
pairs += [('hh_religion', 'wealth_index'), ('state', 'wealth_index')]
for feature, partner in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {partner}")
    shap.dependence_plot(feature, shap_values, dataset, display_features=dataset_display, interaction_index=partner)
# SHAP dependence plots for 'years_edu', each colored by an interaction
# feature, plus 'hh_religion' and 'state' colored by years_edu.
edu_partners = ['years_edu', 'age', 'urban', 'caste_sc/st', 'caste_general',
                'wealth_index', 'no_children_below5', 'total_children']
pairs = [('years_edu', p) for p in edu_partners]
pairs += [('hh_religion', 'years_edu'), ('state', 'years_edu')]
for feature, partner in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {partner}")
    shap.dependence_plot(feature, shap_values, dataset, display_features=dataset_display, interaction_index=partner)
# SHAP dependence plots for the one-hot 'caste_sc/st' indicator, each colored
# by an interaction feature, plus 'hh_religion' and 'state' colored by it.
scst_partners = ['caste_sc/st', 'age', 'urban', 'years_edu', 'wealth_index',
                 'no_children_below5', 'total_children']
pairs = [('caste_sc/st', p) for p in scst_partners]
pairs += [('hh_religion', 'caste_sc/st'), ('state', 'caste_sc/st')]
for feature, partner in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {partner}")
    shap.dependence_plot(feature, shap_values, dataset, display_features=dataset_display, interaction_index=partner)
# SHAP dependence plots for the one-hot 'caste_general' indicator, each
# colored by an interaction feature, plus 'hh_religion' and 'state'.
general_partners = ['caste_general', 'age', 'urban', 'years_edu', 'wealth_index',
                    'no_children_below5', 'total_children']
pairs = [('caste_general', p) for p in general_partners]
pairs += [('hh_religion', 'caste_general'), ('state', 'caste_general')]
for feature, partner in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {partner}")
    shap.dependence_plot(feature, shap_values, dataset, display_features=dataset_display, interaction_index=partner)
# Per-age-group SHAP summaries.
bins = [(21,25), (26,30), (31,35), (36,40), (41,45), (46,50)]
for low, high in bins:
    # BUG FIX: the original used `dataset.age > low`, which silently dropped
    # the boundary ages 21, 26, 31, 36, 41 and 46 (the dataset is filtered to
    # age >= 21 and each bin's low is the previous high + 1). Use an inclusive
    # lower bound so the bins partition the full age range.
    mask = (dataset.age >= low) & (dataset.age <= high)
    dataset_sample = dataset[mask]
    dataset_display_sample = dataset_display[mask]
    targets_sample = targets[mask]
    shap_values_sample = shap_values[mask]
    print("\nAge Range: {} - {} years".format(low, high))
    print("Sample size: {}\n".format(len(dataset_sample)))
    # Per-feature mean |SHAP| within this age group
    for col, sv in zip(dataset_sample.columns, np.abs(shap_values_sample).mean(0)):
        print(f"{col} - {sv}")
    # Summary plots (global bar chart + per-sample beeswarm)
    shap.summary_plot(shap_values_sample, dataset_sample, plot_type="bar")
    shap.summary_plot(shap_values_sample, dataset_display_sample)
SHAP interaction values are a generalization of SHAP values to higher order interactions.
The explainer returns a matrix for every prediction, where the main effects are on the diagonal and the interaction effects are off-diagonal. The main effects are similar to the SHAP values you would get for a linear model, and the interaction effects capture all the higher-order interactions and divide them up among the pairwise interaction terms.
Note that the sum of the entire interaction matrix is the difference between the model's current output and expected output, and so the interaction effects on the off-diagonal are split in half (since there are two of each). When plotting interaction effects the SHAP package automatically multiplies the off-diagonal values by two to get the full interaction effect.
# Draw a weighted subsample (interaction values are expensive to compute,
# so they are only evaluated on 10k rows sampled by survey weight).
dataset_ss = dataset.sample(10000, weights=sample_weights, random_state=random_state)
print(dataset_ss.shape)
# Matching rows from the display copy, aligned by index
dataset_display_ss = dataset_display.loc[dataset_ss.index]
print(dataset_display_ss.shape)
# Load cached SHAP interaction values; recompute with the commented line if
# the cache is missing (very time consuming).
# shap_interaction_values = explainer.shap_interaction_values(dataset_ss)
# BUG FIX: the original pickle.load(open(...)) leaked the file handle; use a
# context manager. NOTE(review): only unpickle trusted, self-produced files.
with open(f'res/{target_col}-{year}-shapints.obj', 'rb') as f:
    shap_interaction_values = pickle.load(f)
shap.summary_plot(shap_interaction_values, dataset_display_ss, max_display=15)
# Heatmap of total |interaction| strength between the strongest feature pairs.
interaction_strength = np.abs(shap_interaction_values).sum(0)
# Zero the main-effect diagonal so only pairwise interactions remain
np.fill_diagonal(interaction_strength, 0)
# Order features by total off-diagonal strength; keep the top 50
order = np.argsort(-interaction_strength.sum(0))[:50]
top = interaction_strength[order, :][:, order]
pl.figure(figsize=(12, 12))
pl.imshow(top)
pl.yticks(range(top.shape[0]), dataset_ss.columns[order], rotation=50.4, horizontalalignment="right")
pl.xticks(range(top.shape[0]), dataset_ss.columns[order], rotation=50.4, horizontalalignment="left")
pl.gca().xaxis.tick_top()
pl.show()
Running a dependence plot on the SHAP interaction values allows us to separately observe the main effects and the interaction effects.
Below we plot the main effects for age and some of the interaction effects for age. It is informative to compare the main effects plot of age with the earlier SHAP value plot for age. The main effects plot has no vertical dispersion because the interaction effects are all captured in the off-diagonal terms.
Good example of how to infer interesting stuff from interaction values: https://slundberg.github.io/shap/notebooks/NHANES%20I%20Survival%20Model.html
# Main effect of age: the ("age", "age") diagonal entry of the interaction
# matrix. It has no vertical dispersion because interaction effects live on
# the off-diagonal entries.
shap.dependence_plot(
    ("age", "age"),
    shap_interaction_values, dataset_ss, display_features=dataset_display_ss
)
Now we plot the interaction effects involving age (and other features after that). These effects capture all of the vertical dispersion that was present in the original SHAP plot but is missing from the main effects plot above.
# Interaction-effect dependence plots involving 'age'
age_partners = ['age', 'urban', 'caste_sc/st', 'caste_general', 'wealth_index',
                'years_edu', 'no_children_below5', 'total_children']
pairs = [('age', p) for p in age_partners]
for feature, partner in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {partner}")
    shap.dependence_plot(
        (feature, partner),
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
# Interaction-effect dependence plots involving 'wealth_index'
wealth_partners = ['wealth_index', 'age', 'urban', 'caste_sc/st', 'caste_general',
                   'years_edu', 'no_children_below5', 'total_children']
pairs = [('wealth_index', p) for p in wealth_partners]
for feature, partner in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {partner}")
    shap.dependence_plot(
        (feature, partner),
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
# Interaction-effect dependence plots involving 'years_edu'
edu_partners = ['years_edu', 'age', 'urban', 'caste_sc/st', 'caste_general',
                'wealth_index', 'no_children_below5', 'total_children']
pairs = [('years_edu', p) for p in edu_partners]
for feature, partner in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {partner}")
    shap.dependence_plot(
        (feature, partner),
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
# Interaction-effect dependence plots involving 'caste_sc/st'
scst_partners = ['caste_sc/st', 'age', 'urban', 'years_edu', 'wealth_index',
                 'no_children_below5', 'total_children']
pairs = [('caste_sc/st', p) for p in scst_partners]
for feature, partner in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {partner}")
    shap.dependence_plot(
        (feature, partner),
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
# Interaction-effect dependence plots involving 'caste_general'
general_partners = ['caste_general', 'age', 'urban', 'years_edu', 'wealth_index',
                    'no_children_below5', 'total_children']
pairs = [('caste_general', p) for p in general_partners]
for feature, partner in pairs:
    print(f"\nFeature: {feature}, Interaction Feature: {partner}")
    shap.dependence_plot(
        (feature, partner),
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )